# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert coco data to mindrecord format."""

import os
import numpy as np
from pycocotools.coco import COCO
from mindspore.mindrecord import FileWriter
from mindvision.common.utils.class_factory import ClassFactory, ModuleType
from datasets.utils.classes import get_classes

@ClassFactory.register(ModuleType.DATASET)
class Ssd2MindRecord:
    """
    Convert a COCO dataset to MindRecord files for SSD.

    Args:
        root (str): Directory that the image file names listed in the
            annotation file are joined onto.
        ann_file (str): Path of the COCO annotation file handed to ``COCO``.
        mindrecord_dir (str): Directory in which the MindRecord files are written.
        is_training (bool): Selects the training writer (8 output files by
            default) or the evaluation writer (1 output file by default). In
            evaluation mode, images carrying any crowd annotation are skipped.
        remove_images_without_annos (bool): Stored on the instance; not read
            by any method of this class. Default: True.
        filter_crowd_anno (bool): Stored on the instance; not read by any
            method of this class. Default: True.
    """

    # Index in this list is the training label; index 0 is reserved for background.
    COCO_CLASSES = list(get_classes(label="COCO"))
    COCO_CLASSES.insert(0, "background")

    def __init__(self, root, ann_file, mindrecord_dir, is_training,
                 remove_images_without_annos=True,
                 filter_crowd_anno=True):
        """Store the conversion settings; no I/O happens here."""
        self.data_root = root
        self.ann_file = ann_file
        self.mindrecord_dir = mindrecord_dir
        self.remove_images_without_annos = remove_images_without_annos
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

    def create_coco_label_ssd(self):
        """
        Collect image paths and box annotations from the COCO annotation file.

        Returns:
            tuple: ``(images, image_path_dict, image_anno_dict)`` where
            ``images`` is the list of kept image ids, ``image_path_dict`` maps
            an image id to its path under ``self.data_root`` and
            ``image_anno_dict`` maps an image id to an array whose rows are
            ``[y_min, x_min, y_max, x_max, label]``.
        """
        coco = COCO(self.ann_file)

        # Class name -> training label (position in COCO_CLASSES).
        train_cls_dict = {cls: i for i, cls in enumerate(self.COCO_CLASSES)}

        # COCO category id -> category name.
        classs_dict = {cat["id"]: cat["name"] for cat in coco.loadCats(coco.getCatIds())}

        images = []
        image_path_dict = {}
        image_anno_dict = {}

        for img_id in coco.getImgIds():
            image_info = coco.loadImgs(img_id)
            file_name = image_info[0]["file_name"]
            anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
            image_path = os.path.join(self.data_root, file_name)

            annos = []
            iscrowd = False
            for label in coco.loadAnns(anno_ids):
                bbox = label["bbox"]
                class_name = classs_dict[label["category_id"]]
                # Remember whether ANY annotation of this image is a crowd region.
                iscrowd = iscrowd or label["iscrowd"]
                # Dict membership is equivalent to scanning COCO_CLASSES, but O(1).
                if class_name in train_cls_dict:
                    # bbox is read as [x, y, width, height]; convert to corner form.
                    x_min, x_max = bbox[0], bbox[0] + bbox[2]
                    y_min, y_max = bbox[1], bbox[1] + bbox[3]
                    annos.append(list(map(round, [y_min, x_min, y_max, x_max])) + [train_cls_dict[class_name]])

            # Evaluation skips images that contain crowd annotations.
            if not self.is_training and iscrowd:
                continue
            # Images without any usable box are dropped.
            if annos:
                images.append(img_id)
                image_path_dict[img_id] = image_path
                image_anno_dict[img_id] = np.array(annos)

        return images, image_path_dict, image_anno_dict

    def _write_mindrecord(self, prefix, file_num, caller):
        """Write all kept images and their annotations to ``mindrecord_dir/prefix``."""
        mindrecord_dir = self.mindrecord_dir
        # Make sure the output directory exists before the writer is handed a
        # path inside it (an empty dir means "current directory": nothing to create).
        if mindrecord_dir:
            os.makedirs(mindrecord_dir, exist_ok=True)
        mindrecord_path = os.path.join(mindrecord_dir, prefix)
        writer = FileWriter(mindrecord_path, file_num)
        images, image_path_dict, image_anno_dict = self.create_coco_label_ssd()
        print("create mindrecord for ssd in mindrecord.py " + caller)
        ssd_json = {
            "img_id": {"type": "int32", "shape": [1]},
            "image": {"type": "bytes"},
            "annotation": {"type": "int32", "shape": [-1, 5]},
        }

        writer.add_schema(ssd_json, "ssd_json")

        for img_id in images:
            # Images are stored as the raw encoded file bytes, not decoded pixels.
            with open(image_path_dict[img_id], 'rb') as f:
                img = f.read()
            row = {
                "img_id": np.array([img_id], dtype=np.int32),
                "image": img,
                "annotation": np.array(image_anno_dict[img_id], dtype=np.int32),
            }
            writer.write_raw_data([row])
        writer.commit()

    def train_data_to_mindrecord_byte_image(self, prefix="coco.mindrecord", file_num=8):
        """
        Create the training MindRecord files.

        Args:
            prefix (str): Output file name inside ``mindrecord_dir``.
            file_num (int): Passed to ``FileWriter`` as its second argument. Default: 8.
        """
        self._write_mindrecord(prefix, file_num, "train_data_to_mindrecord_byte_image")

    def eval_data_to_mindrecord_byte_image(self, prefix="coco.mindrecord", file_num=1):
        """
        Create the evaluation MindRecord file.

        Args:
            prefix (str): Output file name inside ``mindrecord_dir``.
            file_num (int): Passed to ``FileWriter`` as its second argument. Default: 1.
        """
        self._write_mindrecord(prefix, file_num, "eval_data_to_mindrecord_byte_image")

    def __call__(self, prefix='coco.mindrecord'):
        """
        Write the MindRecord file(s) unless they already exist.

        Args:
            prefix (str): Output file name inside ``mindrecord_dir``. The same
                value is used for the existence check and for writing.
        """
        mindrecord_path = os.path.join(self.mindrecord_dir, prefix)

        if self.is_training:
            # Training output is probed via the "<prefix>0" file.
            if os.path.exists(mindrecord_path + "0"):
                print("coco.mindrecord files already exist,not get in train_data_to_mindrecord_byte_image")
                return
            # Forward the prefix so the files written match the path checked above.
            self.train_data_to_mindrecord_byte_image(prefix)
        else:
            if os.path.exists(mindrecord_path):
                print("coco.mindrecord files already exist,not get in eval_data_to_mindrecord_byte_image")
                return
            self.eval_data_to_mindrecord_byte_image(prefix)
